{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "collapsed": true
   },
   "outputs": [],
   "source": [
    "# EXPLAIN WHAT MNIST IS (GIVE EXAMPLE) (SHOW GRAPH STRUCTURE) (MORE VISUALS IN GENERAL)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "deletable": true,
    "editable": true
   },
   "source": [
    "# MNIST\n",
    "\n",
    "In this tutorial, we will show you how to train an actual CNN model, albeit a small one. We will be using the good old MNIST dataset and the LeNet model, with the slight change that the sigmoid activations are replaced with ReLUs.\n",
    "\n",
    "We will use the model helper, which helps us deal with parameter initialization naturally.\n",
    "\n",
    "First, let's import the necessities."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "collapsed": false,
    "deletable": true,
    "editable": true
   },
   "outputs": [],
   "source": [
    "%matplotlib inline\n",
    "from matplotlib import pyplot\n",
    "import numpy as np\n",
    "import os\n",
    "import shutil\n",
    "from IPython import display\n",
    "\n",
    "from caffe2.python import core, model_helper, net_drawer, workspace, visualize, brew\n",
    "\n",
    "# If you would like to see some really detailed initializations,\n",
    "# you can change --caffe2_log_level=0 to --caffe2_log_level=-1\n",
    "core.GlobalInit(['caffe2', '--caffe2_log_level=0'])\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "collapsed": false,
    "deletable": true,
    "editable": true
   },
   "outputs": [],
   "source": [
    "# This section preps your image and test set in a leveldb\n",
    "current_folder = os.getcwd()\n",
    "\n",
    "data_folder = os.path.join(current_folder, 'tutorial_data', 'mnist')\n",
    "root_folder = os.path.join(current_folder, 'tutorial_files', 'mnist')\n",
    "image_file_train = os.path.join(data_folder, \"train-images-idx3-ubyte\")\n",
    "label_file_train = os.path.join(data_folder, \"train-labels-idx1-ubyte\")\n",
    "image_file_test = os.path.join(data_folder, \"t10k-images-idx3-ubyte\")\n",
    "label_file_test = os.path.join(data_folder, \"t10k-labels-idx1-ubyte\")\n",
    "\n",
    "# Get the dataset if it is missing\n",
    "def DownloadDataset(url, path):\n",
    "    \"\"\"Download the zip archive at `url` and extract it into `path`.\n",
    "\n",
    "    Raises requests.HTTPError if the server answers with an error status,\n",
    "    so a failed download is reported as such instead of as a bad zip file.\n",
    "    \"\"\"\n",
    "    import requests, zipfile, StringIO\n",
    "    print \"Downloading... \", url, \" to \", path\n",
    "    # The whole body is read into memory below, so streaming is not needed.\n",
    "    r = requests.get(url)\n",
    "    r.raise_for_status()\n",
    "    z = zipfile.ZipFile(StringIO.StringIO(r.content))\n",
    "    z.extractall(path)\n",
    "\n",
    "def GenerateDB(image, label, name):\n",
    "    \"\"\"Create the leveldb `name` inside data_folder from raw MNIST files.\n",
    "\n",
    "    image: path passed to the converter as --image_file.\n",
    "    label: path passed to the converter as --label_file.\n",
    "    name:  database folder name; it is joined onto data_folder.\n",
    "\n",
    "    An existing database folder is reused; only a leftover LOCK file in it\n",
    "    is removed. Raises subprocess.CalledProcessError if the converter exits\n",
    "    with a non-zero status, and OSError if the converter cannot be started.\n",
    "    \"\"\"\n",
    "    import subprocess\n",
    "    name = os.path.join(data_folder, name)\n",
    "    print 'DB: ', name\n",
    "    if not os.path.exists(name):\n",
    "        # Pass the arguments as a list (no shell involved) so that paths\n",
    "        # containing spaces or shell metacharacters reach the tool intact.\n",
    "        command = [\n",
    "            \"/usr/local/binaries/make_mnist_db\", \"--channel_first\",\n",
    "            \"--db\", \"leveldb\",\n",
    "            \"--image_file\", image,\n",
    "            \"--label_file\", label,\n",
    "            \"--output_file\", name,\n",
    "        ]\n",
    "        print \"Creating database with: \", \" \".join(command)\n",
    "        # Stop here if the conversion fails rather than later, when a net\n",
    "        # tries to read a database that was never written.\n",
    "        subprocess.check_call(command)\n",
    "    else:\n",
    "        print \"Database exists already. Delete the folder if you have issues/corrupted DB, then rerun this.\"\n",
    "        if os.path.exists(os.path.join(name, \"LOCK\")):\n",
    "            print \"Deleting the pre-existing lock file\"\n",
    "            os.remove(os.path.join(name, \"LOCK\"))\n",
    "\n",
    "if not os.path.exists(data_folder):\n",
    "    os.makedirs(data_folder)\n",
    "if not os.path.exists(label_file_train):\n",
    "    DownloadDataset(\"https://s3.amazonaws.com/caffe2/datasets/mnist/mnist.zip\", data_folder)\n",
    "\n",
    "# Start every run from an empty workspace root folder.\n",
    "if os.path.exists(root_folder):\n",
    "    print(\"Looks like you ran this before, so we need to cleanup those old files...\")\n",
    "    shutil.rmtree(root_folder)\n",
    "\n",
    "os.makedirs(root_folder)\n",
    "workspace.ResetWorkspace(root_folder)\n",
    "\n",
    "# (Re)generate the leveldb database (known to get corrupted...)\n",
    "GenerateDB(image_file_train, label_file_train, \"mnist-train-nchw-leveldb\")\n",
    "GenerateDB(image_file_test, label_file_test, \"mnist-test-nchw-leveldb\")\n",
    "\n",
    "print(\"training data folder:\" + data_folder)\n",
    "print(\"workspace root folder:\" + root_folder)\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "collapsed": false,
    "deletable": true,
    "editable": true
   },
   "outputs": [],
   "source": [
    "def AddInput(model, batch_size, db, db_type):\n",
    "    \"\"\"Add the input operators to `model` and return the (data, label) blobs.\n",
    "\n",
    "    model:      model helper that receives the operators.\n",
    "    batch_size: batch size handed to the database reader.\n",
    "    db:         path of the database to read from.\n",
    "    db_type:    database type string, e.g. 'leveldb'.\n",
    "    \"\"\"\n",
    "    # Read the raw image and label blobs from the database.\n",
    "    raw_images, label = model.TensorProtosDBInput(\n",
    "        [],\n",
    "        [\"data_uint8\", \"label\"],\n",
    "        batch_size=batch_size,\n",
    "        db=db,\n",
    "        db_type=db_type)\n",
    "    # Cast the image blob to float.\n",
    "    data = model.Cast(raw_images, \"data\", to=core.DataType.FLOAT)\n",
    "    # Rescale the values in place by a factor of 1/256.\n",
    "    data = model.Scale(data, data, scale=1.0 / 256)\n",
    "    # The input does not need a gradient in the backward pass.\n",
    "    data = model.StopGradient(data, data)\n",
    "    return data, label"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "collapsed": false,
    "deletable": true,
    "editable": true
   },
   "outputs": [],
   "source": [
    "def AddLeNetModel(model, data):\n",
    "    \"\"\"Add the LeNet layers to `model` on top of `data`; return the softmax blob.\"\"\"\n",
    "    # First stage: 5x5 convolution (28 x 28 -> 24 x 24),\n",
    "    # then 2x2 max pooling (24 x 24 -> 12 x 12).\n",
    "    conv1 = brew.conv(model, data, 'conv1', dim_in=1, dim_out=20, kernel=5)\n",
    "    pool1 = brew.max_pool(model, conv1, 'pool1', kernel=2, stride=2)\n",
    "    # Second stage: 5x5 convolution (12 x 12 -> 8 x 8),\n",
    "    # then 2x2 max pooling (8 x 8 -> 4 x 4).\n",
    "    conv2 = brew.conv(model, pool1, 'conv2', dim_in=20, dim_out=50, kernel=5)\n",
    "    pool2 = brew.max_pool(model, conv2, 'pool2', kernel=2, stride=2)\n",
    "    # Fully connected layer; its input size 50 * 4 * 4 is the previous\n",
    "    # dim_out multiplied by the remaining image size.\n",
    "    hidden = brew.fc(model, pool2, 'fc3', dim_in=50 * 4 * 4, dim_out=500)\n",
    "    # The ReLU writes its result back into the same blob.\n",
    "    hidden = brew.relu(model, hidden, hidden)\n",
    "    # Output layer with 10 outputs, followed by a softmax.\n",
    "    pred = brew.fc(model, hidden, 'pred', dim_in=500, dim_out=10)\n",
    "    return brew.softmax(model, pred, 'softmax')"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "collapsed": false,
    "deletable": true,
    "editable": true
   },
   "outputs": [],
   "source": [
    "def AddAccuracy(model, softmax, label):\n",
    "    \"\"\"Add an accuracy operator over (softmax, label); return the 'accuracy' blob.\"\"\"\n",
    "    return brew.accuracy(model, [softmax, label], \"accuracy\")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "collapsed": false,
    "deletable": true,
    "editable": true
   },
   "outputs": [],
   "source": [
    "def AddTrainingOperators(model, softmax, label):\n",
    "    \"\"\"Add loss, accuracy, gradient, parameter-update and checkpoint operators.\n",
    "\n",
    "    model:   model helper that already contains the forward pass.\n",
    "    softmax: prediction blob produced by the network.\n",
    "    label:   label blob produced by the input operators.\n",
    "    \"\"\"\n",
    "    # Cross entropy between the predictions and the labels ...\n",
    "    cross_entropy = model.LabelCrossEntropy([softmax, label], 'xent')\n",
    "    # ... averaged into a single loss blob.\n",
    "    loss = model.AveragedLoss(cross_entropy, \"loss\")\n",
    "    # Track the accuracy of the model while training.\n",
    "    AddAccuracy(model, softmax, label)\n",
    "    # Add the gradient operators, starting from the averaged loss.\n",
    "    model.AddGradientOperators([loss])\n",
    "    # Iteration counter that drives the learning rate schedule.\n",
    "    iteration = brew.iter(model, \"iter\")\n",
    "    # Step policy with a negative base rate: the update below *adds*\n",
    "    # param_grad * LR to each parameter.\n",
    "    learning_rate = model.LearningRate(\n",
    "        iteration, \"LR\", base_lr=-0.1, policy=\"step\", stepsize=1, gamma=0.999)\n",
    "    # Constant weight used in the update. It only has to be created once,\n",
    "    # so it is placed explicitly in param_init_net.\n",
    "    one = model.param_init_net.ConstantFill([], \"ONE\", shape=[1], value=1.0)\n",
    "    # Plain SGD as a weighted sum: param = param * ONE + param_grad * LR.\n",
    "    # The model helper keeps the gradient of every parameter in param_to_grad.\n",
    "    for param in model.params:\n",
    "        model.WeightedSum(\n",
    "            [param, one, model.param_to_grad[param], learning_rate], param)\n",
    "    # Checkpoint the iteration counter and the parameters every 20 iterations.\n",
    "    # If a re-run trips over an existing checkpoint db, delete the old\n",
    "    # checkpoint files first.\n",
    "    model.Checkpoint(\n",
    "        [iteration] + model.params, [],\n",
    "        db=\"mnist_lenet_checkpoint_%05d.leveldb\",\n",
    "        db_type=\"leveldb\",\n",
    "        every=20)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "collapsed": false,
    "deletable": true,
    "editable": true
   },
   "outputs": [],
   "source": [
    "arg_scope = {\"order\": \"NCHW\"}\n",
    "train_db = os.path.join(data_folder, 'mnist-train-nchw-leveldb')\n",
    "test_db = os.path.join(data_folder, 'mnist-test-nchw-leveldb')\n",
    "\n",
    "# Training model: data input, the LeNet layers and the training operators.\n",
    "train_model = model_helper.ModelHelper(name=\"mnist_train\", arg_scope=arg_scope)\n",
    "data, label = AddInput(\n",
    "    train_model, batch_size=64, db=train_db, db_type='leveldb')\n",
    "softmax = AddLeNetModel(train_model, data)\n",
    "AddTrainingOperators(train_model, softmax, label)\n",
    "\n",
    "# Testing model: data input, the LeNet layers and an accuracy operator.\n",
    "# With a batch size of 100, a testing pass of 100 iterations covers\n",
    "# 10,000 images in total. init_params is False because the parameters\n",
    "# obtained from the train model are used.\n",
    "test_model = model_helper.ModelHelper(\n",
    "    name=\"mnist_test\", arg_scope=arg_scope, init_params=False)\n",
    "data, label = AddInput(\n",
    "    test_model, batch_size=100, db=test_db, db_type='leveldb')\n",
    "softmax = AddLeNetModel(test_model, data)\n",
    "AddAccuracy(test_model, softmax, label)\n",
    "\n",
    "# Deployment model: only the LeNet layers, fed from a blob named \"data\".\n",
    "deploy_model = model_helper.ModelHelper(\n",
    "    name=\"mnist_deploy\", arg_scope=arg_scope, init_params=False)\n",
    "AddLeNetModel(deploy_model, \"data\")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "collapsed": false,
    "deletable": true,
    "editable": true
   },
   "outputs": [],
   "source": [
    "# Draw the operators of the training net as a left-to-right graph.\n",
    "train_ops = train_model.net.Proto().op\n",
    "graph = net_drawer.GetPydotGraphMinimal(\n",
    "    train_ops, \"mnist\", rankdir=\"LR\", minimal_dependency=True)\n",
    "display.Image(graph.create_png(), width=800)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "collapsed": false,
    "deletable": true,
    "editable": true
   },
   "outputs": [],
   "source": [
    "# Write the text form of every net definition into the root folder.\n",
    "net_dumps = [\n",
    "    (\"train_net.pbtxt\", train_model.net),\n",
    "    (\"train_init_net.pbtxt\", train_model.param_init_net),\n",
    "    (\"test_net.pbtxt\", test_model.net),\n",
    "    (\"test_init_net.pbtxt\", test_model.param_init_net),\n",
    "    (\"deploy_net.pbtxt\", deploy_model.net),\n",
    "]\n",
    "for pbtxt_name, net in net_dumps:\n",
    "    with open(os.path.join(root_folder, pbtxt_name), 'w') as fid:\n",
    "        fid.write(str(net.Proto()))\n",
    "print(\"Protocol buffers files have been created in your root folder: \"+root_folder)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "collapsed": false,
    "deletable": true,
    "editable": true
   },
   "outputs": [],
   "source": [
    "# Initialize the parameters; this network only needs to run once.\n",
    "workspace.RunNetOnce(train_model.param_init_net)\n",
    "# Create the training network in the workspace.\n",
    "workspace.CreateNet(train_model.net)\n",
    "train_net_name = train_model.net.Proto().name\n",
    "\n",
    "# Number of iterations to run, and per-iteration accuracy and loss records.\n",
    "total_iters = 200\n",
    "accuracy = np.zeros(total_iters)\n",
    "loss = np.zeros(total_iters)\n",
    "\n",
    "# Run the network manually, one iteration at a time, and record the\n",
    "# accuracy and loss blobs after each run.\n",
    "for i in range(total_iters):\n",
    "    workspace.RunNet(train_net_name)\n",
    "    accuracy[i] = workspace.FetchBlob('accuracy')\n",
    "    loss[i] = workspace.FetchBlob('loss')\n",
    "\n",
    "# Plot the recorded values once training is done.\n",
    "pyplot.plot(loss, 'b')\n",
    "pyplot.plot(accuracy, 'r')\n",
    "pyplot.legend(('Loss', 'Accuracy'), loc='upper right')"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "collapsed": false,
    "deletable": true,
    "editable": true
   },
   "outputs": [],
   "source": [
    "# Fetch the input and prediction blobs left in the workspace.\n",
    "data = workspace.FetchBlob('data')\n",
    "softmax = workspace.FetchBlob('softmax')\n",
    "\n",
    "# First figure: the images in the data blob.\n",
    "pyplot.figure()\n",
    "_ = visualize.NCHW.ShowMultiple(data)\n",
    "\n",
    "# Second figure: the softmax output for the first image.\n",
    "pyplot.figure()\n",
    "_ = pyplot.plot(softmax[0], 'ro')\n",
    "pyplot.title('Prediction for the first image')"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "collapsed": false,
    "deletable": true,
    "editable": true
   },
   "outputs": [],
   "source": [
    "# Run a test pass on the test net.\n",
    "workspace.RunNetOnce(test_model.param_init_net)\n",
    "workspace.CreateNet(test_model.net)\n",
    "# Number of test batches to run; it sizes both the loop and the record array.\n",
    "test_iters = 100\n",
    "test_accuracy = np.zeros(test_iters)\n",
    "for i in range(test_iters):\n",
    "    workspace.RunNet(test_model.net.Proto().name)\n",
    "    test_accuracy[i] = workspace.FetchBlob('accuracy')\n",
    "# After the execution is done, plot the per-batch accuracy and print its mean.\n",
    "pyplot.plot(test_accuracy, 'r')\n",
    "pyplot.title('Accuracy over test batches.')\n",
    "print('test_accuracy: %f' % test_accuracy.mean())"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "collapsed": true
   },
   "outputs": [],
   "source": []
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 2",
   "language": "python",
   "name": "python2"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 2
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython2",
   "version": "2.7.6"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 0
}
